import torch
import torch.onnx

from CNN import ConvNet

# Instantiate the network architecture; the trained weights are loaded next.
test_net = ConvNet()
# Load the saved state dict. map_location='cpu' remaps any tensors that were
# saved from a GPU run onto the CPU. Without it, the load fails on a machine
# without CUDA. It also keeps the parameters on the same device as
# dummy_input below, which is created on the CPU for tracing.
test_net.load_state_dict(torch.load('cnnmodel.pth', map_location='cpu'))
test_net.eval()  # Set the model to evaluation mode (any dropout/batch-norm layers use inference behavior)

# Example input used only to trace the graph during export: shape (1, 1, 28, 28).
# NOTE(review): presumably (batch, channels, height, width) for single-channel
# 28x28 images -- confirm this matches what ConvNet expects.
dummy_input = torch.randn(1, 1, 28, 28, device='cpu')

# Export to ONNX (opset 10). No dynamic_axes are given, so the exported graph
# uses the fixed input shape of dummy_input (including batch size 1).
# verbose=True prints the traced graph while exporting.
torch.onnx.export(test_net, dummy_input, "model.onnx", opset_version=10,
                  verbose=True,
                  input_names=["input"],    # name of the ONNX graph input
                  output_names=["output"])  # name of the ONNX graph output